[ROCm] Fix biased wgrad with fp32 gradient accumulation #634

Open
XinyuJiangCMU wants to merge 5 commits into ROCm:dev from XinyuJiangCMU:rocm-wgrad-bgrad-dbias-fix-v2

Problem

On ROCm, hipBLASLt cannot find a suitable algorithm for an fp32 weight gradient GEMM with fused bias gradient computation.

This causes training with --add-qkv-bias and --accumulate-allreduce-grads-in-fp32 to fail with:

RuntimeError: Unable to find any suitable algorithms

Fix

Run the weight gradient GEMM without the fused bias gradient and compute the bias gradient separately by summing grad_output.

The fix is implemented in general_gemm, covering delayed weight gradient execution and other callers using the same path. CUDA behavior is unchanged.
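
As a rough illustration of the split described above (not the PR's actual code), the sketch below computes the weight gradient with a plain fp32-output GEMM and reduces the bias gradient separately in PyTorch. The helper name `wgrad_with_separate_dbias` and its signature are hypothetical. It upcasts the bf16 inputs to fp32 before the matmul for simplicity, whereas the real path keeps bf16 inputs and lets the GEMM produce fp32 output.

```python
import torch


def wgrad_with_separate_dbias(inp, grad_output, bias_dtype=torch.bfloat16):
    """Hypothetical helper: unfused wgrad GEMM plus a separate bias-gradient sum."""
    # Flatten all leading (batch/sequence) dims into a single "tokens" dim.
    inp_2d = inp.reshape(-1, inp.shape[-1])                 # [tokens, in_features]
    dy_2d = grad_output.reshape(-1, grad_output.shape[-1])  # [tokens, out_features]

    # Weight gradient with fp32 output and no bias-gradient epilogue: dW = dY^T @ X.
    wgrad = dy_2d.float().t() @ inp_2d.float()              # [out_features, in_features]

    # Bias gradient: sum grad_output over tokens in fp32, then cast to the bias dtype.
    grad_bias = dy_2d.float().sum(dim=0).to(bias_dtype)     # [out_features]
    return wgrad, grad_bias


torch.manual_seed(0)
x = torch.randn(4, 8, 16, dtype=torch.bfloat16)   # [batch, seq, in_features]
dy = torch.randn(4, 8, 32, dtype=torch.bfloat16)  # [batch, seq, out_features]
wgrad, grad_bias = wgrad_with_separate_dbias(x, dy)
```

For a linear layer without a fused activation, the bias gradient is just the sum of `grad_output` over the token dimension, so the separate reduction should match what autograd produces for `torch.nn.functional.linear`, up to bf16 rounding of the final cast.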

Testing

Verified on MI350X:

  • The isolated reproduction passes.

  • Qwen2.5-0.5B GSM8K training passes the previously failing backward step.

  • Re-enabled the ROCm wgrad numerics tests previously skipped by grouped GEMM change 434:

    • test_linear_accuracy_delay_wgrad_compute
    • test_layernorm_linear_accuracy_delay_wgrad_compute
    • test_layernorm_mlp_accuracy_delay_wgrad_compute

    All three use general_gemm.

Result:

132 passed, 0 skipped, 0 failed

XinyuJiangCMU and others added 5 commits June 18, 2026 04:41
On ROCm, hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad
GEMM that also fuses the bias-gradient (BGRADB) epilogue: the heuristic
returns zero algorithms and the GEMM raises "Unable to find any suitable
algorithms". This hits any LayerNormLinear with bias (e.g. Qwen2.5 QKV
with add-qkv-bias) when training with fp32 gradient accumulation
(--accumulate-allreduce-grads-in-fp32).

When wgrad is accumulated into an fp32 main_grad on ROCm, skip the fused
dbias and reduce grad_bias separately (grad_output.sum over tokens in
fp32, cast to bias dtype) -- mathematically identical to the BGRADB
epilogue. CUDA and all other paths are unchanged.

Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Move the BGRADB-unfuse workaround from the per-module LayerNormLinear backward
up to general_gemm, the single chokepoint every wgrad path funnels through.
This covers Linear, LayerNormLinear, LayerNormMLP and the delayed-wgrad store
in one place, and fixes the delayed-wgrad path, where the per-module version
dropped the bias gradient. CUDA, the forward bias-add path and fp8/fp4 are
untouched.

Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover the non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the ROCm numerics test that was skipped for this case.

Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the Linear / LayerNormLinear / LayerNormMLP wgrad numerics tests skipped for this case; GroupedLinear routes through general_grouped_gemm and stays skipped.

Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
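
The conditions the commit messages give for splitting the bias gradient can be summarized as a small predicate. This is only a sketch of the decision logic as described above: the `WgradRequest` type, its field names, and `should_split_bias_grad` are hypothetical and do not come from the PR's code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WgradRequest:
    """Hypothetical description of one wgrad GEMM call."""
    is_rocm: bool          # running on ROCm (hipBLASLt backend)
    out_dtype: str         # GEMM output dtype, e.g. "float32" or "bfloat16"
    wants_bias_grad: bool  # caller asked for the fused bias gradient
    gelu: bool             # a gelu epilogue is involved
    accumulate: bool       # accumulating into an existing fp32 main_grad


def should_split_bias_grad(req: WgradRequest) -> bool:
    """Return True when the bias gradient should be computed outside the GEMM."""
    # CUDA behavior is unchanged, so only ROCm calls are considered.
    if not req.is_rocm or not req.wants_bias_grad:
        return False
    # With gelu the bias gradient is not a plain grad_output sum, so keep it fused.
    if req.gelu:
        return False
    # The failure depends on the fp32 output dtype, not on `accumulate`.
    return req.out_dtype == "float32"


first_microbatch = WgradRequest(True, "float32", True, False, accumulate=False)
later_microbatch = WgradRequest(True, "float32", True, False, accumulate=True)
print(should_split_bias_grad(first_microbatch), should_split_bias_grad(later_microbatch))
```

Note that `accumulate` is deliberately ignored by the predicate, mirroring the observation that the first (non-accumulating) microbatch hits the same failure as later ones.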